//ControlStatusRegisterUnit.scala
package mycore

import chisel3._
import chisel3.util._
import difftest._

/** Machine-mode control and status register (CSR) unit.
  *
  * Decodes the Zicsr instructions (CSRRW/CSRRS/CSRRC and their immediate forms), MRET and ECALL
  * from `io.inst`, returns the old value of the addressed CSR on `io.dataout`, updates the
  * implemented M-mode CSRs, signals timer-interrupt / environment-call traps, and mirrors the CSR
  * state onto the difftest `DiffCSRStateIO` port.
  *
  * NOTE(review): `io.ENABLE` is assumed to mark the cycle in which `io.inst` is really executed.
  * Every CSR write below is gated on it; only the free-running `mcycle` increment is not.
  */
class ControlStatusRegisterUnit extends Module {
  val io = IO(new Bundle {
    val ENABLE  = Input(Bool())      // io.inst executes this cycle (see class note)
    val pc      = Input(UInt(64.W))  // PC of io.inst; saved into mepc on a trap
    val inst    = Input(UInt(64.W))  // instruction being executed
    val src     = Input(UInt(64.W))  // register operand of CSRRW/CSRRS/CSRRC (presumably x[rs1])
    val dataout = Output(UInt(64.W)) // old value of the addressed CSR, 0 otherwise

    val TimeOver = Input(Bool())     // drives mip bit 7 (MTIP) when mip is not being written

    val TimerInterrupt  = Output(Bool())
    val EnvironmentCall = Output(Bool())

    // difftest
    val CSRState = Flipped(new DiffCSRStateIO)
  })

  // ---------------------------------------------------------------- instruction decode
  val CSRRW   = (io.inst === BitPat("b????????????_?????_001_?????_1110011"))
  val CSRRWI  = (io.inst === BitPat("b????????????_?????_101_?????_1110011"))
  val CSRRS   = (io.inst === BitPat("b????????????_?????_010_?????_1110011"))
  val CSRRSI  = (io.inst === BitPat("b????????????_?????_110_?????_1110011"))
  val CSRRC   = (io.inst === BitPat("b????????????_?????_011_?????_1110011"))
  val CSRRCI  = (io.inst === BitPat("b????????????_?????_111_?????_1110011"))
  val isCSR   = CSRRC || CSRRCI || CSRRS || CSRRSI || CSRRW || CSRRWI
  val isMRET  = (io.inst === BitPat("b001100000010_00000_000_00000_1110011"))
  val isECALL = (io.inst === BitPat("b000000000000_00000_000_00000_1110011"))
  val zimm    = ZeroExt(io.inst(19, 15), 64)
  val addr    = io.inst(31, 20)

  // CSRRS/CSRRC with rs1 = x0 and CSRRSI/CSRRCI with uimm = 0 are pure reads. The ISA says they
  // must not write the CSR, so e.g. `csrr x, mcycle` must not disturb mcycle or mip.
  // (The rs1 field and the uimm field are the same bits, inst[19:15].)
  val csrWen = isCSR && (CSRRW || CSRRWI || io.inst(19, 15) =/= 0.U)

  // ---------------------------------------------------------------- registers
  val mcycle   = RegInit(0.U(64.W))
  val mstatus  = RegInit(0.U(64.W))
  val mtvec    = RegInit(0.U(64.W))
  val mcause   = RegInit(0.U(64.W))
  val mepc     = RegInit(0.U(64.W))
  val mie      = RegInit(0.U(64.W))
  val mip      = RegInit(0.U(64.W))
  val mscratch = RegInit(0.U(64.W))
  val medeleg  = RegInit(0.U(64.W))
  // mhartid (0xF14) is a read-only CSR; this core reports hart 0 and ignores writes to it.
  val mhartid  = 0.U(64.W)

  // ---------------------------------------------------------------- address decode
  val vis_mcycle   = addr === "hb00".U(12.W)
  val vis_mstatus  = addr === "h300".U(12.W)
  val vis_mtvec    = addr === "h305".U(12.W)
  val vis_mcause   = addr === "h342".U(12.W)
  val vis_mepc     = addr === "h341".U(12.W)
  val vis_mie      = addr === "h304".U(12.W)
  val vis_mip      = addr === "h344".U(12.W)
  val vis_mscratch = addr === "h340".U(12.W)
  val vis_medeleg  = addr === "h302".U(12.W)
  val vis_mhartid  = addr === "hf14".U(12.W)

  // ---------------------------------------------------------------- read (old CSR value)
  // At most one vis_* is true, so OR-ing the gated values selects the addressed CSR (or 0).
  val temp_mcycle   = Mux(isCSR && vis_mcycle,   mcycle,   0.U)
  val temp_mstatus  = Mux(isCSR && vis_mstatus,  mstatus,  0.U)
  val temp_mtvec    = Mux(isCSR && vis_mtvec,    mtvec,    0.U)
  val temp_mcause   = Mux(isCSR && vis_mcause,   mcause,   0.U)
  val temp_mepc     = Mux(isCSR && vis_mepc,     mepc,     0.U)
  val temp_mie      = Mux(isCSR && vis_mie,      mie,      0.U)
  val temp_mip      = Mux(isCSR && vis_mip,      mip,      0.U)
  val temp_mscratch = Mux(isCSR && vis_mscratch, mscratch, 0.U)
  val temp_medeleg  = Mux(isCSR && vis_medeleg,  medeleg,  0.U)
  val temp_mhartid  = Mux(isCSR && vis_mhartid,  mhartid,  0.U)
  val temp = temp_mcycle | temp_mstatus | temp_mtvec | temp_mcause | temp_mepc | temp_mie |
             temp_mip | temp_mscratch | temp_medeleg | temp_mhartid

  // ---------------------------------------------------------------- exception detect
  val MIE  = mstatus(3)
  val MPIE = mstatus(7)
  val MPP  = mstatus(12, 11)
  val MTIE = mie(7)
  val MTIP = mip(7)
  val TIV  = MIE && (MTIE && MTIP)                  // TimerInterruptValid
  val TIVR = RegEnable(TIV, false.B, io.ENABLE)     // TIV as sampled at the last ENABLE cycle
  // TimerInterrupt fires on the rising edge of TIV (as seen by executed instructions):
  // TIV  : ...00011111...
  // TIVR : ...00001111...
  // TI   : ...00010000...
  val TimerInterrupt  = TIV && !TIVR
  val EnvironmentCall = isECALL
  io.TimerInterrupt  := TimerInterrupt
  io.EnvironmentCall := EnvironmentCall

  // ---------------------------------------------------------------- write data
  // SignExt of a 1-bit select yields an all-ones or all-zeros mask, so exactly one term survives.
  val data_rc  = SignExt(CSRRC.asUInt,  64) & (temp & ~io.src)
  val data_rci = SignExt(CSRRCI.asUInt, 64) & (temp & ~zimm)
  val data_rs  = SignExt(CSRRS.asUInt,  64) & (temp | io.src)
  val data_rsi = SignExt(CSRRSI.asUInt, 64) & (temp | zimm)
  val data_rw  = SignExt(CSRRW.asUInt,  64) & (io.src)
  val data_rwi = SignExt(CSRRWI.asUInt, 64) & (zimm)
  val wdata    = data_rc | data_rci | data_rs | data_rsi | data_rw | data_rwi

  // mstatus.SD (bit 63) is read-only and summarises the dirty state:
  //   FS = mstatus(16,15), XS = mstatus(14,13), SD = (FS == 3) || (XS == 3)
  val SD = wdata(16, 15) === 3.U || wdata(14, 13) === 3.U

  // Trap entry: MPP := 3, MPIE := MIE, MIE := 0. Every other mstatus field is preserved.
  val Trap = Cat(mstatus(63, 13), 3.U(2.W), mstatus(10, 8), MIE, mstatus(6, 4), 0.U(1.W), mstatus(2, 0))

  // MRET: MIE := MPIE, MPIE := 1, MPP := 0. Every other mstatus field is preserved.
  // NOTE(review): the privileged spec sets MPP to the least-privileged *supported* mode; 0 (U) is
  // kept here as in the original design - confirm against the difftest reference model.
  val Return = Cat(mstatus(63, 13), 0.U(2.W), mstatus(10, 8), 1.U(1.W), mstatus(6, 4), MPIE, mstatus(2, 0))

  // mcycle counts every clock; an executed CSR write replaces the count for that cycle. The write
  // is gated on io.ENABLE like all other CSR writes, so a held instruction is not re-applied.
  when(io.ENABLE && csrWen && vis_mcycle) {
    mcycle := wdata
  }.otherwise {
    mcycle := mcycle + 1.U
  }

  when(io.ENABLE) {
    when(TimerInterrupt || EnvironmentCall) {
      mstatus := Trap
    }.elsewhen(isMRET) {
      mstatus := Return
    }.elsewhen(csrWen && vis_mstatus) {
      mstatus := Cat(SD.asUInt, wdata(62, 0))
    }

    when(csrWen && vis_mtvec) { mtvec := wdata }

    when(csrWen && vis_mie) { mie := wdata }

    // Unless software writes mip, bit 7 (MTIP) follows io.TimeOver and all other bits read 0.
    when(csrWen && vis_mip) {
      mip := wdata
    }.otherwise {
      mip := Mux(io.TimeOver, "h0000_0000_0000_0080".U, 0.U)
    }

    // Interrupt has priority over ECALL; bit 63 marks an interrupt cause.
    when(TimerInterrupt) {
      mcause := "h8000_0000_0000_0007".U(64.W)
    }.elsewhen(EnvironmentCall) {
      mcause := "h0000_0000_0000_000b".U(64.W)
    }.elsewhen(csrWen && vis_mcause) {
      mcause := wdata
    }

    when(TimerInterrupt || EnvironmentCall) {
      mepc := io.pc
    }.elsewhen(csrWen && vis_mepc) {
      mepc := wdata
    }

    when(csrWen && vis_mscratch) { mscratch := wdata }
    when(csrWen && vis_medeleg)  { medeleg  := wdata }
    // mhartid is read-only: writes are intentionally ignored.
  }

  // ---------------------------------------------------------------- read port
  io.dataout := temp

  // ---------------------------------------------------------------- debug
  // MissCSR flags a CSR instruction whose address matches none of the implemented CSRs.
  val HIT = vis_mcycle || vis_mstatus || vis_mtvec || vis_mcause || vis_mepc || vis_mie ||
            vis_mip || vis_mscratch || vis_medeleg || vis_mhartid
  val MissCSR = isCSR && !HIT

  // ---------------------------------------------------------------- difftest
  io.CSRState.clock          := clock
  io.CSRState.coreid         := 0.U
  io.CSRState.mstatus        := mstatus
  io.CSRState.mcause         := mcause
  io.CSRState.mepc           := mepc
  io.CSRState.sstatus        := mstatus & "h8000_0003_000d_e122".U
  io.CSRState.scause         := 0.U
  io.CSRState.sepc           := 0.U
  io.CSRState.satp           := 0.U
  io.CSRState.mip            := 0.U
  io.CSRState.mie            := mie
  io.CSRState.mscratch       := mscratch
  io.CSRState.sscratch       := 0.U
  io.CSRState.mideleg        := 0.U
  io.CSRState.medeleg        := medeleg
  io.CSRState.mtval          := 0.U
  io.CSRState.stval          := 0.U
  io.CSRState.mtvec          := mtvec
  io.CSRState.stvec          := 0.U
  io.CSRState.priviledgeMode := 3.U

}